{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MhoQ0WE77laV"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "_ckMIh7O7s6D"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jYysdyb-CaWM"
      },
      "source": [
        "# tf.distribute.Strategy を使用したカスタムトレーニング"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S5Uhzt6vVIB2"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "<img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\"><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/tutorials/distribute/custom_training.ipynb\">TensorFlow.org で表示</a>\n",
        "</td>\n",
        "  <td>\n",
        "<img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\"><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ja/tutorials/distribute/custom_training.ipynb\">Google Colab で実行</a>\n",
        "</td>\n",
        "  <td>\n",
        "<img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\"><a target=\"_blank\" href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ja/tutorials/distribute/custom_training.ipynb\">GitHub でソースを表示</a>\n",
        "</td>\n",
        "  <td>\n",
        "<img src=\"https://www.tensorflow.org/images/download_logo_32px.png\"><a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/distribute/custom_training.ipynb\">ノートブックをダウンロード</a>\n",
        "</td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FbVhjPpzn6BM"
      },
      "source": [
        "このチュートリアルでは、[`tf.distribute.Strategy`](https://www.tensorflow.org/guide/distributed_training)をカスタムトレーニングループで使用する方法を示します。Fashion MNIST データセットで単純な CNN モデルをトレーニングします。Fashion MNIST データセットには、サイズ 28×28 のトレーニング画像 60,000 枚とサイズ 28×28 のテスト画像 10,000 枚が含まれています。\n",
        "\n",
        "モデルのトレーニングにカスタムトレーニングループを使用するのは、柔軟性があり、トレーニングを容易に制御できるからです。それに加えて、モデルとトレーニングループのデバッグも容易になります。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dzLKpmZICaWN"
      },
      "outputs": [],
      "source": [
        "# Import TensorFlow\n",
        "import tensorflow as tf\n",
        "\n",
        "# Helper libraries\n",
        "import numpy as np\n",
        "import os\n",
        "\n",
        "print(tf.__version__)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MM6W__qraV55"
      },
      "source": [
        "## Fashion MNIST データセットをダウンロードする"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7MqDQO0KCaWS"
      },
      "outputs": [],
      "source": [
        "fashion_mnist = tf.keras.datasets.fashion_mnist\n",
        "\n",
        "(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()\n",
        "\n",
        "# Adding a dimension to the array -&gt; new shape == (28, 28, 1)\n",
        "# We are doing this because the first layer in our model is a convolutional\n",
        "# layer and it requires a 4D input (batch_size, height, width, channels).\n",
        "# batch_size dimension will be added later on.\n",
        "train_images = train_images[..., None]\n",
        "test_images = test_images[..., None]\n",
        "\n",
        "# Getting the images in [0, 1] range.\n",
        "train_images = train_images / np.float32(255)\n",
        "test_images = test_images / np.float32(255)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4AXoHhrsbdF3"
      },
      "source": [
        "## 変数とグラフを分散させるストラテジーを作成する"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5mVuLZhbem8d"
      },
      "source": [
        "`tf.distribute.MirroredStrategy`ストラテジーはどのように機能するのでしょう？\n",
        "\n",
        "- 全ての変数とモデルグラフはレプリカ上に複製されます。\n",
        "- 入力はレプリカ全体に均等に分散されます。\n",
        "- 各レプリカは受け取った入力の損失と勾配を計算します。\n",
        "- 勾配は加算して全てのレプリカ間で同期されます。\n",
        "- 同期後、各レプリカ上の変数のコピーにも同じ更新が行われます。\n",
        "\n",
        "注意: 下記のコードは全て 1 つのスコープ内に入れることができます。説明しやすいように、ここでは幾つかのコードセルに分割しています。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F2VeZUWUj5S4"
      },
      "outputs": [],
      "source": [
        "# If the list of devices is not specified in the\n",
        "# `tf.distribute.MirroredStrategy` constructor, it will be auto-detected.\n",
        "strategy = tf.distribute.MirroredStrategy()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZngeM_2o0_JO"
      },
      "outputs": [],
      "source": [
        "print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k53F5I_IiGyI"
      },
      "source": [
        "## 入力パイプラインをセットアップする"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Qb6nDgxiN_n"
      },
      "source": [
        "グラフと変数をプラットフォームに依存しない SavedModel 形式にエクスポートします。モデルが保存された後、スコープの有無に関わらずそれを読み込むことができます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jwJtsCQhHK-E"
      },
      "outputs": [],
      "source": [
        "BUFFER_SIZE = len(train_images)\n",
        "\n",
        "BATCH_SIZE_PER_REPLICA = 64\n",
        "GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync\n",
        "\n",
        "EPOCHS = 10"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J7fj3GskHC8g"
      },
      "source": [
        "データセットを作成して、それを配布します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WYrMNNDhAvVl"
      },
      "outputs": [],
      "source": [
        "train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) \n",
        "test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE) \n",
        "\n",
        "train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)\n",
        "test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bAXAo_wWbWSb"
      },
      "source": [
        "## モデルを作成する\n",
        "\n",
        "`tf.keras.Sequential`を使用してモデルを作成します。これは Model Subclassing API を使用しても作成できます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9ODch-OFCaW4"
      },
      "outputs": [],
      "source": [
        "def create_model():\n",
        "  model = tf.keras.Sequential([\n",
        "      tf.keras.layers.Conv2D(32, 3, activation='relu'),\n",
        "      tf.keras.layers.MaxPooling2D(),\n",
        "      tf.keras.layers.Conv2D(64, 3, activation='relu'),\n",
        "      tf.keras.layers.MaxPooling2D(),\n",
        "      tf.keras.layers.Flatten(),\n",
        "      tf.keras.layers.Dense(64, activation='relu'),\n",
        "      tf.keras.layers.Dense(10)\n",
        "    ])\n",
        "\n",
        "  return model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9iagoTBfijUz"
      },
      "outputs": [],
      "source": [
        "# Create a checkpoint directory to store the checkpoints.\n",
        "checkpoint_dir = './training_checkpoints'\n",
        "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e-wlFFZbP33n"
      },
      "source": [
        "## 損失関数を定義する\n",
        "\n",
        "通常、GPU/CPU を 1 つ搭載した単一のマシンでは、損失は入力バッチ内の例の数で除算されます。\n",
        "\n",
        "*では、`tf.distribute.Strategy`を使用する場合、どのように損失を計算すればよいのでしょうか？*\n",
        "\n",
        "- 例えば、4 つの GPU と 64 のバッチサイズがあるとします。1 つの入力バッチは（4 つの GPU の）レプリカに分散されるので、各レプリカはサイズ 16 の入力を取得します。\n",
        "\n",
        "- 各レプリカのモデルは、それぞれの入力でフォワードパスを実行し、損失を計算します。ここでは、損失をそれぞれの入力の例の数（BATCH_SIZE_PER_REPLICA = 16）で除算するのではなく、損失を GLOBAL_BATCH_SIZE (64) で除算する必要があります。\n",
        "\n",
        "*なぜそうするのでしょう？*\n",
        "\n",
        "- 勾配を各レプリカで計算した後にそれらを**加算**してレプリカ間で同期するため、これを行う必要があります。\n",
        "\n",
        "*TensorFlow でこれを行うには？*\n",
        "\n",
        "- このチュートリアルにもあるように、カスタムトレーニングループを書く場合は、例ごとの損失を加算し、その合計を GLOBAL_BATCH_SIZE: `scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE)`で除算する必要があります。または、`tf.nn.compute_average_loss`を使用することも可能です。これは例ごとの損失、オプションのサンプルの重み、そしてGLOBAL_BATCH_SIZE を引数として取り、スケーリングされた損失を返します。\n",
        "\n",
        "- モデルで正則化損失を使用している場合は、損失値をレプリカの数でスケーリングする必要があります。これを行うには、`tf.nn.scale_regularization_loss`関数を使用します。\n",
        "\n",
        "- `tf.reduce_mean`の使用は推奨されません。これを使用すると、損失がレプリカごとの実際のバッチサイズで除算され、ステップごとに変化する場合があります。\n",
        "\n",
        "- この縮小とスケーリングは、keras`model.compile`と`model.fit`で自動的に行われます。\n",
        "\n",
        "- 以下の例のように`tf.keras.losses`クラスを使用する場合、損失削減は`NONE`または`SUM`のいずれかになるよう、明示的に指定する必要があります。`AUTO`および`SUM_OVER_BATCH_SIZE`の`tf.distribute.Strategy`との併用は許可されません。`AUTO`は、分散型のケースでユーザーがどの削減を正しいと確認するか明示的に考える必要があるため、許可されていません。`SUM_OVER_BATCH_SIZE`は、現時点ではレプリカのバッチサイズのみで除算され、レプリカの数に基づく除算はユーザーに任されており見落としがちなため、許可されていません。そのため、その代わりにユーザー自身が明示的に削減を行うようにお願いしています。\n",
        "\n",
        "- もし`labels`が多次元である場合は、各サンプルの要素数全体で`per_example_loss`を平均化します。例えば、`predictions`の形状が`(batch_size, H, W, n_classes)`で、`labels`が`(batch_size, H, W)`の場合、`per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)`のように`per_example_loss`を更新する必要があります。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "R144Wci782ix"
      },
      "outputs": [],
      "source": [
        "with strategy.scope():\n",
        "  # Set reduction to `none` so we can do the reduction afterwards and divide by\n",
        "  # global batch size.\n",
        "  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n",
        "      from_logits=True,\n",
        "      reduction=tf.keras.losses.Reduction.NONE)\n",
        "  def compute_loss(labels, predictions):\n",
        "    per_example_loss = loss_object(labels, predictions)\n",
        "    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w8y54-o9T2Ni"
      },
      "source": [
        "## 損失と精度を追跡するメトリクスを定義する\n",
        "\n",
        "これらのメトリクスは、テストの損失、トレーニング、テストの精度を追跡します。`.result()`を使用して、いつでも累積統計を取得できます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zt3AHb46Tr3w"
      },
      "outputs": [],
      "source": [
        "with strategy.scope():\n",
        "  test_loss = tf.keras.metrics.Mean(name='test_loss')\n",
        "\n",
        "  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
        "      name='train_accuracy')\n",
        "  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
        "      name='test_accuracy')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iuKuNXPORfqJ"
      },
      "source": [
        "## トレーニングループ"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OrMmakq5EqeQ"
      },
      "outputs": [],
      "source": [
        "# model, optimizer, and checkpoint must be created under `strategy.scope`.\n",
        "with strategy.scope():\n",
        "  model = create_model()\n",
        "\n",
        "  optimizer = tf.keras.optimizers.Adam()\n",
        "\n",
        "  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3UX43wUu04EL"
      },
      "outputs": [],
      "source": [
        "def train_step(inputs):\n",
        "  images, labels = inputs\n",
        "\n",
        "  with tf.GradientTape() as tape:\n",
        "    predictions = model(images, training=True)\n",
        "    loss = compute_loss(labels, predictions)\n",
        "\n",
        "  gradients = tape.gradient(loss, model.trainable_variables)\n",
        "  optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
        "\n",
        "  train_accuracy.update_state(labels, predictions)\n",
        "  return loss \n",
        "\n",
        "def test_step(inputs):\n",
        "  images, labels = inputs\n",
        "\n",
        "  predictions = model(images, training=False)\n",
        "  t_loss = loss_object(labels, predictions)\n",
        "\n",
        "  test_loss.update_state(t_loss)\n",
        "  test_accuracy.update_state(labels, predictions)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gX975dMSNw0e"
      },
      "outputs": [],
      "source": [
        "# `run` replicates the provided computation and runs it\n",
        "# with the distributed input.\n",
        "@tf.function\n",
        "def distributed_train_step(dataset_inputs):\n",
        "  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))\n",
        "  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,\n",
        "                         axis=None)\n",
        "\n",
        "@tf.function\n",
        "def distributed_test_step(dataset_inputs):\n",
        "  return strategy.run(test_step, args=(dataset_inputs,))\n",
        "\n",
        "for epoch in range(EPOCHS):\n",
        "  # TRAIN LOOP\n",
        "  total_loss = 0.0\n",
        "  num_batches = 0\n",
        "  for x in train_dist_dataset:\n",
        "    total_loss += distributed_train_step(x)\n",
        "    num_batches += 1\n",
        "  train_loss = total_loss / num_batches\n",
        "\n",
        "  # TEST LOOP\n",
        "  for x in test_dist_dataset:\n",
        "    distributed_test_step(x)\n",
        "\n",
        "  if epoch % 2 == 0:\n",
        "    checkpoint.save(checkpoint_prefix)\n",
        "\n",
        "  template = (\"Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, \"\n",
        "              \"Test Accuracy: {}\")\n",
        "  print (template.format(epoch+1, train_loss,\n",
        "                         train_accuracy.result()*100, test_loss.result(),\n",
        "                         test_accuracy.result()*100))\n",
        "\n",
        "  test_loss.reset_states()\n",
        "  train_accuracy.reset_states()\n",
        "  test_accuracy.reset_states()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z1YvXqOpwy08"
      },
      "source": [
        "上記の例における注意点:\n",
        "\n",
        "-  for文（`for x in ...`）を使用して、`train_dist_dataset`と`test_dist_dataset`に対してイテレーションしています。\n",
        "- スケーリングされた損失は `distributed_train_step`の戻り値です。この値は`tf.distribute.Strategy.reduce`呼び出しを使用してレプリカ間で集約され、次に`tf.distribute.Strategy.reduce`呼び出しの戻り値を加算してバッチ間で集約されます。\n",
        "- `tf.keras.Metrics`は、`tf.distribute.Strategy.run`によって実行される`train_step`と`test_step`内で更新される必要があります。*`tf.distribute.Strategy.run`はストラテジー内の各ローカルレプリカの結果を返し、この結果の消費方法は多様です。`tf.distribute.Strategy.reduce`を実行して、集約された値を取得することができます。また、`tf.distribute.Strategy.experimental_local_results`を実行して、ローカルレプリカごとに 1 つ、結果に含まれる値のリストを取得することもできます。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-q5qp31IQD8t"
      },
      "source": [
        "## 最新のチェックポイントを復元してテストする"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WNW2P00bkMGJ"
      },
      "source": [
        "`tf.distribute.Strategy`でチェックポイントされたモデルは、ストラテジーの有無に関わらず復元することができます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pg3B-Cw_cn3a"
      },
      "outputs": [],
      "source": [
        "eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(\n",
        "      name='eval_accuracy')\n",
        "\n",
        "new_model = create_model()\n",
        "new_optimizer = tf.keras.optimizers.Adam()\n",
        "\n",
        "test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7qYii7KUYiSM"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def eval_step(images, labels):\n",
        "  predictions = new_model(images, training=False)\n",
        "  eval_accuracy(labels, predictions)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LeZ6eeWRoUNq"
      },
      "outputs": [],
      "source": [
        "checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)\n",
        "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))\n",
        "\n",
        "for images, labels in test_dataset:\n",
        "  eval_step(images, labels)\n",
        "\n",
        "print ('Accuracy after restoring the saved model without strategy: {}'.format(\n",
        "    eval_accuracy.result()*100))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EbcI87EEzhzg"
      },
      "source": [
        "## データセットのイテレーションの代替方法\n",
        "\n",
        "### イテレータを使用する\n",
        "\n",
        "データセット全体ではなく、任意のステップ数のイテレーションを行いたい場合は、`iter`呼び出しを使用してイテレータを作成し、そのイテレータ上で`next`を明示的に呼び出すことができます。tf.function の内側と外側の両方でデータセットのイテレーションを選択することができます。ここでは、イテレータを使用し tf.function 外側のデータセットのイテレーションを実行する小さなスニペットを示します。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7c73wGC00CzN"
      },
      "outputs": [],
      "source": [
        "for _ in range(EPOCHS):\n",
        "  total_loss = 0.0\n",
        "  num_batches = 0\n",
        "  train_iter = iter(train_dist_dataset)\n",
        "\n",
        "  for _ in range(10):\n",
        "    total_loss += distributed_train_step(next(train_iter))\n",
        "    num_batches += 1\n",
        "  average_train_loss = total_loss / num_batches\n",
        "\n",
        "  template = (\"Epoch {}, Loss: {}, Accuracy: {}\")\n",
        "  print (template.format(epoch+1, average_train_loss, train_accuracy.result()*100))\n",
        "  train_accuracy.reset_states()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GxVp48Oy0m6y"
      },
      "source": [
        "### tf.function 内でイテレーションする\n",
        "\n",
        "tf.function の内側で for 文（`for x in ...`）を使用して、あるいは上記で行ったようにイテレータを作成して、入力`train_dist_dataset`全体をイテレーションすることもできます。次の例では、トレーニングの 1 つのエポックを tf.function でラップし、関数内で`train_dist_dataset`をイテレーションする方法を示します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-REzmcXv00qm"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def distributed_train_epoch(dataset):\n",
        "  total_loss = 0.0\n",
        "  num_batches = 0\n",
        "  for x in dataset:\n",
        "    per_replica_losses = strategy.run(train_step, args=(x,))\n",
        "    total_loss += strategy.reduce(\n",
        "      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)\n",
        "    num_batches += 1\n",
        "  return total_loss / tf.cast(num_batches, dtype=tf.float32)\n",
        "\n",
        "for epoch in range(EPOCHS):\n",
        "  train_loss = distributed_train_epoch(train_dist_dataset)\n",
        "\n",
        "  template = (\"Epoch {}, Loss: {}, Accuracy: {}\")\n",
        "  print (template.format(epoch+1, train_loss, train_accuracy.result()*100))\n",
        "\n",
        "  train_accuracy.reset_states()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MuZGXiyC7ABR"
      },
      "source": [
        "### レプリカ間でトレーニング損失を追跡する\n",
        "\n",
        "注意: 一般的なルールとして、サンプルごとの値の追跡には`tf.keras.Metrics`を使用し、レプリカ内で集約された値を避ける必要があります。\n",
        "\n",
        "`tf.metrics.Mean`を使用すると損失スケーリングが計算されるため、異なるレプリカ間のトレーニング損失の追跡には*推奨できません*。\n",
        "\n",
        "例えば、次のような特徴を持つトレーニングジョブを実行するとします。\n",
        "\n",
        "- レプリカ 2 つ\n",
        "- 各レプリカで 2 つのサンプルを処理\n",
        "- 結果の損失値 : 各レプリカで [2,  3] および [4,  5]\n",
        "- グローバルバッチサイズ = 4\n",
        "\n",
        "損失スケーリングで損失値を加算して各レプリカのサンプルごとの損失の値を計算し、さらにグローバルバッチサイズで除算します。この場合は、`(2 + 3) / 4 = 1.25`および`(4 + 5) / 4 = 2.25`となります。\n",
        "\n",
        "`tf.metrics.Mean`を使用して 2 つのレプリカ間の損失を追跡する場合、結果は異なります。この例では、`total`は 3.50、`count`は 2 となり、メトリック上で`result()`が呼び出された場合の結果は`total`/`count` = 1.75 となります。`tf.keras.Metrics`で計算された損失は、同期するレプリカの数に等しい追加の係数によってスケーリングされます。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xisYJaV9KZTN"
      },
      "source": [
        "### ガイドと例\n",
        "\n",
        "カスタムトレーニングループを用いた分散ストラテジーの使用例をここに幾つか示します。\n",
        "\n",
        "1. [分散型トレーニングガイド](../../guide/distributed_training)\n",
        "2. `MirroredStrategy`を使用した [DenseNet](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/densenet/distributed_train.py) の例。\n",
        "3. `MirroredStrategy`と`TPUStrategy`を使用してトレーニングされた [BERT](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_classifier.py) の例。この例は、分散トレーニングなどの間にチェックポイントから読み込む方法と、定期的にチェックポイントを生成する方法を理解するのに特に有用です。\n",
        "4. `MirroredStrategy`を使用してトレーニングされ、`keras_use_ctl`フラグを使用した有効化が可能な、[NCF](https://github.com/tensorflow/models/blob/master/official/recommendation/ncf_keras_main.py)の例。\n",
        "5. `MirroredStrategy`を使用してトレーニングされた、[NMT](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/nmt_with_attention/distributed_train.py) の例。\n",
        "\n",
        "[分散ストラテジーガイド](../../guide/distributed_training.ipynb#examples_and_tutorials)には他の例も記載されています。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6hEJNsokjOKs"
      },
      "source": [
        "## 次のステップ\n",
        "\n",
        "- 新しい`tf.distribute.Strategy` API を独自のモデルで試してみましょう。\n",
        "- 他のストラテジーや独自の TensorFlow モデルのパフォーマンス最適化に使用できる[ツール](../../guide/profiler.md)についての詳細は、ガイドの[パフォーマンスのセクション](../../guide/function.ipynb)をご覧ください。"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "custom_training.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
